
import numpy as np
import matplotlib
from src.data_management import DataHandler, Settings
from src.methods import FixedFeature, UnconditionalMetaLearningFeature, ConditionalMetaLearningFeature
from src.plotting import plot_stuff
import time
import datetime


def _experiment_config(exp):
    """Return the configuration of the experiment called ``exp``.

    Args:
        exp: one of ``'exp_synthetic_2_clusters'`` (Figure 1 left),
            ``'exp_synthetic_6_clusters'`` (Figure 1 right),
            ``'exp_real_lenk'`` (Figure 2 top-left),
            ``'exp_real_movies'`` (Figure 2 top-right) or
            ``'exp_real_jester'`` (Figure 2 bottom).

    Returns:
        Tuple ``(loss_name, feature_map_name, data_settings, r, W)``. A new
        ``data_settings`` dict is built on every call, so callers may mutate it.

    Raises:
        ValueError: if ``exp`` is not a known experiment name.
    """
    if exp in ('exp_synthetic_2_clusters', 'exp_synthetic_6_clusters'):
        # The two synthetic experiments differ only in the number of clusters.
        n_clusters = 2 if exp == 'exp_synthetic_2_clusters' else 6
        data_settings = {'dataset': f'synthetic-regression-feature-{n_clusters}-CLUSTERS',
                         'n_tr_tasks': 500,
                         'n_val_tasks': 300,
                         'n_test_tasks': 100,
                         'n_all_points': 80,
                         'ts_points_pct': 0.5,
                         'n_dims': 20,
                         'noise_std': 0.1,
                         'number_clusters': n_clusters,
                         'sparsity': 2}
        return 'absolute', 'linear_with_labels', data_settings, None, None
    if exp == 'exp_real_lenk':
        data_settings = {'dataset': 'lenk',
                         'n_tr_tasks': 100,
                         'n_val_tasks': 40,
                         'n_test_tasks': 30}
        return 'absolute', 'ls_regressor', data_settings, None, None
    if exp == 'exp_real_movies':
        # 943 tasks in total, n_tot = d = 939.
        # Other (train, val, test) splits: (400, 100, 100) and (700, 100, 143).
        data_settings = {'dataset': 'movies',
                         'n_tr_tasks': 200,
                         'n_val_tasks': 100,
                         'n_test_tasks': 100,
                         'ts_points_pct': 0.25}
        return 'absolute', 'linear_with_labels', data_settings, 5, None
    if exp == 'exp_real_jester':
        data_settings = {'dataset': 'jester',
                         'n_tr_tasks': 250,
                         'n_val_tasks': 100,
                         'n_test_tasks': 100,
                         'ts_points_pct': 0.25}
        # In the case of recommender systems, the score is put in r.
        return 'absolute', 'recommenders', data_settings, 5, None
    raise ValueError(f'Unknown experiment: {exp!r}')


def main(exp='exp_synthetic_2_clusters', trials=5):
    """Run every method on ``trials`` seeds of one experiment, then save and plot.

    For each seed in ``range(trials)``:
      - the NumPy global RNG is seeded;
      - a ``DataHandler`` is built from the experiment's data settings;
      - each method is fitted;
      - the value returned by ``model.fit()`` is appended to ``results[method]``.

    Afterwards the results dict is saved with ``np.save`` and passed to
    ``plot_stuff``. To reload it, use
    ``np.load(path, allow_pickle=True).item()``.

    Args:
        exp: experiment name; see ``_experiment_config`` for the choices.
            Defaults to ``'exp_synthetic_2_clusters'`` (Figure 1 left).
        trials: number of seeds (``0 .. trials - 1``) to run. Must be >= 1.

    Returns:
        The ``results`` dict mapping each method name to its per-seed errors.

    Raises:
        ValueError: if ``exp`` is unknown or ``trials < 1``.
    """
    loss_name, feature_map_name, data_settings, r, W = _experiment_config(exp)
    if trials < 1:
        # The settings object from the last trial names the output file, so at
        # least one trial must run.
        raise ValueError(f'trials must be >= 1, got {trials}')
    methods = ['ITL', 'unconditional', 'conditional']

    matplotlib.rc('font', size=26)

    # hyper-parameters range for Feature Learning
    lambda_par_range_feature = [1]  # inner regularization parameter lambda
    gamma_par_range_feature = [10 ** i for i in np.linspace(-5, 5, 14)]  # meta-step size gamma

    results = {curr_method: [] for curr_method in methods}

    tt = time.perf_counter()
    settings = None
    for seed in range(trials):
        print('SEED : ', seed, ' ---------------------------------------')
        np.random.seed(seed)
        general_settings = {'seed': seed,
                            'verbose': 1}

        settings = Settings(data_settings, 'data')
        settings.add_settings(general_settings)
        data = DataHandler(settings)

        print('DATASET: ', settings.data.dataset)

        for curr_method in methods:
            if curr_method == 'ITL':
                # Baseline: a fixed identity feature matrix sized to the input dimension.
                model = FixedFeature(data, np.eye(data.features_tr[0].shape[1]),
                                     lambda_par_range_feature, loss_name)
            elif curr_method == 'unconditional':
                model = UnconditionalMetaLearningFeature(data, lambda_par_range_feature,
                                                         gamma_par_range_feature, loss_name)
            else:  # 'conditional'
                model = ConditionalMetaLearningFeature(data, lambda_par_range_feature,
                                                       gamma_par_range_feature, loss_name,
                                                       feature_map_name, r, W,
                                                       settings.data.dataset)

            errors = model.fit()
            results[curr_method].append(errors)

            # Elapsed seconds since the first seed started (cumulative).
            print('%s done %5.2f' % (curr_method, time.perf_counter() - tt))

        print('seed: %d | %5.2f sec' % (seed, time.perf_counter() - tt))

    # ':' is stripped from the timestamp so the name is a valid file name everywhere.
    timestamp = str(datetime.datetime.now()).replace(':', '')
    np.save(f'{settings.data.dataset}_temp_test_error_{timestamp}.npy', results)
    plot_stuff(results, methods, settings.data.dataset)

    # Return normally instead of terminating the interpreter, so callers can reuse main().
    return results


if __name__ == '__main__':
    # Run the experiment driver only when executed as a script, not on import.
    main()
